{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Environment\n",
    "\n",
    "Prerequisite:\n",
    "\n",
    "- PyTorch (tested on 1.8.1/1.10.1)\n",
    "- Pyro (tested on 1.6.0)\n",
    "\n",
    "We recommend PyTorch 1.8.1, on which the current implementation of the PnP solver runs significantly faster than on PyTorch 1.10.1.\n",
    "\n",
    "Install the python packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# CUDA 11.1\n",
    "%pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html\n",
    "%pip install pyro-ppl==1.6.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clone and enter this project:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/tjiiv-cprg/EPro-PnP\n",
    "%cd EPro-PnP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit the Identity Function\n",
    "\n",
    "Here we demonstrate the usage of EPro-PnP by fitting a simple model `out_pose = EProPnP(MLP(in_pose))` to data points generated from the identity function $I: SE(3) \\to SE(3)$. The model takes `in_pose = [x, y, z, w, i, j, k]` (a translation followed by a unit quaternion) as input, converts it into a 2D-3D correspondence set with a plain MLP (for demonstration only; in practice it makes no sense to use an MLP to predict the full correspondence set), and outputs the probabilistic pose through the EPro-PnP layer. The training set consists of 65536 data points whose target poses are perturbed with additional noise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data as Data\n",
    "\n",
    "from epropnp.epropnp import EProPnP6DoF\n",
    "from epropnp.levenberg_marquardt import LMSolver, RSLMSolver\n",
    "from epropnp.camera import PerspectiveCamera\n",
    "from epropnp.cost_fun import AdaptiveHuberPnPCost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so that the random draws made with torch below\n",
    "# are repeatable when the notebook is re-run from a fresh kernel.\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "n_data = 65536  # number of training samples\n",
    "batch_size = 256\n",
    "n_epoch = 10\n",
    "noise = 0.01  # scale of the Gaussian noise added to the target poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    \"\"\"Toy pose regressor ``out_pose = EProPnP(MLP(in_pose))``.\n",
    "\n",
    "    A plain MLP maps a 7-D input pose to ``num_points`` 2D-3D correspondences\n",
    "    (3D point, 2D point and 2D weight per pair), which are passed to the\n",
    "    EPro-PnP layer together with the camera and the cost function.\n",
    "\n",
    "    Note: the default ``epropnp``, ``camera`` and ``cost_fun`` objects are\n",
    "    evaluated once, when the class is defined, so every ``Model`` built\n",
    "    without explicit arguments shares the same instances.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            num_points=64,  # number of 2D-3D pairs\n",
    "            mlp_layers=[1024],  # a single hidden layer\n",
    "            epropnp=EProPnP6DoF(\n",
    "                mc_samples=512,\n",
    "                num_iter=4,\n",
    "                solver=LMSolver(\n",
    "                    dof=6,\n",
    "                    num_iter=10,\n",
    "                    init_solver=RSLMSolver(\n",
    "                        dof=6,\n",
    "                        num_points=8,\n",
    "                        num_proposals=128,\n",
    "                        num_iter=5))),\n",
    "            camera=PerspectiveCamera(),\n",
    "            cost_fun=AdaptiveHuberPnPCost(\n",
    "                relative_delta=0.5)):\n",
    "        super().__init__()\n",
    "        self.num_points = num_points\n",
    "        # prepend the input width (7-D pose); this builds a new list, so the\n",
    "        # default argument itself is never mutated\n",
    "        mlp_layers = [7] + mlp_layers\n",
    "        mlp = []\n",
    "        for i in range(len(mlp_layers) - 1):\n",
    "            mlp.append(nn.Linear(mlp_layers[i], mlp_layers[i + 1]))\n",
    "            mlp.append(nn.LeakyReLU())\n",
    "        # output head: per point 3 (x3d) + 2 (x2d) + 2 (w2d) values\n",
    "        mlp.append(nn.Linear(mlp_layers[-1], num_points * (3 + 2 + 2)))\n",
    "        self.mlp = nn.Sequential(*mlp)\n",
    "        # Here we use static weight_scale because the data noise is homoscedastic\n",
    "        self.log_weight_scale = nn.Parameter(torch.zeros(2))\n",
    "        self.epropnp = epropnp\n",
    "        self.camera = camera\n",
    "        self.cost_fun = cost_fun\n",
    "\n",
    "    def forward_correspondence(self, in_pose):\n",
    "        \"\"\"Predict the 2D-3D correspondence set from the input pose.\n",
    "\n",
    "        Args:\n",
    "            in_pose: Shape (num_obj, 7)\n",
    "\n",
    "        Returns:\n",
    "            x3d: Shape (num_obj, num_points, 3)\n",
    "            x2d: Shape (num_obj, num_points, 2)\n",
    "            w2d: Shape (num_obj, num_points, 2), softmax over the points\n",
    "                dimension scaled by ``exp(log_weight_scale)``\n",
    "        \"\"\"\n",
    "        x3d, x2d, w2d = self.mlp(in_pose).reshape(-1, self.num_points, 7).split([3, 2, 2], dim=-1)\n",
    "        w2d = (w2d.log_softmax(dim=-2) + self.log_weight_scale).exp()\n",
    "        # equivalent to:\n",
    "        #     w2d = w2d.softmax(dim=-2) * self.log_weight_scale.exp()\n",
    "        return x3d, x2d, w2d\n",
    "\n",
    "    def forward_train(self, in_pose, cam_mats, out_pose):\n",
    "        \"\"\"Run the Monte Carlo forward pass used for training.\n",
    "\n",
    "        Args:\n",
    "            in_pose: Shape (num_obj, 7)\n",
    "            cam_mats: camera parameters handed to ``camera.set_param``\n",
    "            out_pose: pose handed to the layer as ``pose_init``\n",
    "\n",
    "        Returns:\n",
    "            The six outputs of ``epropnp.monte_carlo_forward`` (pose_opt, cost,\n",
    "            pose_opt_plus, pose_samples, pose_sample_logweights, cost_tgt)\n",
    "            followed by norm_factor, the detached mean of\n",
    "            ``exp(log_weight_scale)``.\n",
    "        \"\"\"\n",
    "        x3d, x2d, w2d = self.forward_correspondence(in_pose)\n",
    "        self.camera.set_param(cam_mats)\n",
    "        self.cost_fun.set_param(x2d.detach(), w2d)  # compute dynamic delta\n",
    "        pose_opt, cost, pose_opt_plus, pose_samples, pose_sample_logweights, cost_tgt = self.epropnp.monte_carlo_forward(\n",
    "            x3d,\n",
    "            x2d,\n",
    "            w2d,\n",
    "            self.camera,\n",
    "            self.cost_fun,\n",
    "            pose_init=out_pose,\n",
    "            force_init_solve=True,\n",
    "            with_pose_opt_plus=True)  # True for derivative regularization loss\n",
    "        # read this module's own parameter, not a notebook-level variable\n",
    "        norm_factor = self.log_weight_scale.detach().exp().mean()\n",
    "        return pose_opt, cost, pose_opt_plus, pose_samples, pose_sample_logweights, cost_tgt, norm_factor\n",
    "\n",
    "    def forward_test(self, in_pose, cam_mats, fast_mode=False):\n",
    "        \"\"\"Predict a single pose per input (a mode of the distribution).\n",
    "\n",
    "        Args:\n",
    "            in_pose: Shape (num_obj, 7)\n",
    "            cam_mats: camera parameters handed to ``camera.set_param``\n",
    "            fast_mode: forwarded to the EPro-PnP layer\n",
    "\n",
    "        Returns:\n",
    "            pose_opt: first output of the EPro-PnP layer call\n",
    "        \"\"\"\n",
    "        x3d, x2d, w2d = self.forward_correspondence(in_pose)\n",
    "        self.camera.set_param(cam_mats)\n",
    "        self.cost_fun.set_param(x2d.detach(), w2d)\n",
    "        # returns a mode of the distribution\n",
    "        pose_opt, _, _, _ = self.epropnp(\n",
    "            x3d, x2d, w2d, self.camera, self.cost_fun,\n",
    "            fast_mode=fast_mode)  # fast_mode=True activates Gauss-Newton solver (no trust region)\n",
    "        return pose_opt\n",
    "        # or returns weighted samples drawn from the distribution (slower):\n",
    "        #     _, _, _, pose_samples, pose_sample_logweights, _ = self.epropnp.monte_carlo_forward(\n",
    "        #         x3d, x2d, w2d, self.camera, self.cost_fun, fast_mode=fast_mode)\n",
    "        #     pose_sample_weights = pose_sample_logweights.softmax(dim=0)\n",
    "        #     return pose_samples, pose_sample_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MonteCarloPoseLoss(nn.Module):\n",
    "    \"\"\"Monte Carlo pose loss divided by a running-average normalization factor.\"\"\"\n",
    "\n",
    "    def __init__(self, init_norm_factor=1.0, momentum=0.1):\n",
    "        super().__init__()\n",
    "        self.momentum = momentum\n",
    "        initial_factor = torch.tensor(init_norm_factor, dtype=torch.float)\n",
    "        self.register_buffer('norm_factor', initial_factor)\n",
    "\n",
    "    def forward(self, pose_sample_logweights, cost_target, norm_factor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pose_sample_logweights: Shape (mc_samples, num_obj)\n",
    "            cost_target: Shape (num_obj, )\n",
    "            norm_factor: Shape ()\n",
    "\n",
    "        Returns:\n",
    "            Scalar tensor: mean over objects of\n",
    "            ``cost_target + logsumexp(pose_sample_logweights)`` (NaN entries\n",
    "            replaced by 0), divided by the running ``norm_factor`` buffer.\n",
    "        \"\"\"\n",
    "        if self.training:\n",
    "            # exponential moving average of the incoming factor, outside autograd\n",
    "            with torch.no_grad():\n",
    "                self.norm_factor.mul_(1 - self.momentum)\n",
    "                self.norm_factor.add_(self.momentum * norm_factor)\n",
    "\n",
    "        log_sum = torch.logsumexp(pose_sample_logweights, dim=0)  # (num_obj, )\n",
    "        per_obj_loss = cost_target + log_sum  # (num_obj, )\n",
    "\n",
    "        # objects whose loss is NaN contribute zero\n",
    "        nan_mask = torch.isnan(per_obj_loss)\n",
    "        per_obj_loss[nan_mask] = 0\n",
    "\n",
    "        return per_obj_loss.mean() / self.norm_factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- synthetic data ---\n",
    "# inputs: random translation (columns 0-2) and random unit quaternion (columns 3-6)\n",
    "in_pose = torch.randn(n_data, 7, device=device)\n",
    "in_pose[:, 2] += 5  # shift z so that points do not fall behind the camera plane\n",
    "in_pose[:, 3:] = F.normalize(in_pose[:, 3:], dim=-1)  # unit quaternion\n",
    "\n",
    "# targets: the inputs plus Gaussian noise, with the quaternion part re-normalized\n",
    "out_pose = in_pose + noise * torch.randn(n_data, 7, device=device)\n",
    "out_pose[:, 3:] = F.normalize(out_pose[:, 3:], dim=-1)  # unit quaternion\n",
    "\n",
    "# a single 3x3 identity camera matrix, expanded per batch later on\n",
    "cam_mats = torch.eye(3, device=device)\n",
    "\n",
    "dataset = Data.TensorDataset(in_pose, out_pose)\n",
    "loader = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# --- model, loss and optimizer ---\n",
    "model = Model().to(device)\n",
    "mc_loss_fun = MonteCarloPoseLoss().to(device)\n",
    "\n",
    "# the MLP uses the default learning rate; log_weight_scale gets its own, larger one\n",
    "optimizer = torch.optim.Adam(\n",
    "    [\n",
    "        {'params': model.mlp.parameters()},\n",
    "        {'params': model.log_weight_scale, 'lr': 1e-2},\n",
    "    ],\n",
    "    lr=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start training\n",
    "for epoch_id in range(n_epoch):\n",
    "    for iter_id, (batch_in_pose, batch_out_pose) in enumerate(loader):  # for each training step\n",
    "        # expand the shared camera matrix along a leading batch dimension\n",
    "        batch_cam_mats = cam_mats.expand(batch_in_pose.size(0), -1, -1)\n",
    "        # forward_train returns 7 values; pose_opt, cost and pose_samples are not used here\n",
    "        _, _, pose_opt_plus, _, pose_sample_logweights, cost_tgt, norm_factor = model.forward_train(\n",
    "            batch_in_pose,\n",
    "            batch_cam_mats,\n",
    "            batch_out_pose)\n",
    "\n",
    "        # monte carlo pose loss\n",
    "        loss_mc = mc_loss_fun(\n",
    "            pose_sample_logweights,\n",
    "            cost_tgt, \n",
    "            norm_factor)\n",
    "\n",
    "        # derivative regularization\n",
    "        # translation term: quadratic below beta, linear above (smooth-L1 shape)\n",
    "        dist_t = (pose_opt_plus[:, :3] - batch_out_pose[:, :3]).norm(dim=-1)\n",
    "        beta = 1.0\n",
    "        loss_t = torch.where(dist_t < beta, 0.5 * dist_t.square() / beta,\n",
    "                             dist_t - 0.5 * beta)\n",
    "        loss_t = loss_t.mean()\n",
    "\n",
    "        # rotation term: per-sample quaternion dot product via a batched (1x4)@(4x1) matmul;\n",
    "        # squaring the dot product makes the term insensitive to the quaternion sign\n",
    "        dot_quat = (pose_opt_plus[:, None, 3:] @ batch_out_pose[:, 3:, None]).squeeze(-1).squeeze(-1)\n",
    "        loss_r = (1 - dot_quat.square()) * 2\n",
    "        loss_r = loss_r.mean()\n",
    "\n",
    "        # total loss: Monte Carlo loss plus the two regularization terms, each weighted by 0.1\n",
    "        loss = loss_mc + 0.1 * loss_t + 0.1 * loss_r\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # global L2 norm over all parameter gradients; only used in the log line below\n",
    "        grad_norm = []\n",
    "        for p in model.parameters():\n",
    "            if (p.grad is None) or (not p.requires_grad):\n",
    "                continue\n",
    "            else:\n",
    "                grad_norm.append(torch.norm(p.grad.detach()))\n",
    "        grad_norm = torch.norm(torch.stack(grad_norm))\n",
    "        \n",
    "        optimizer.step()\n",
    "\n",
    "        print('Epoch {}: {}/{} - loss_mc={:.4f}, loss_t={:.4f}, loss_r={:.4f}, loss={:.4f}, norm_factor={:.4f}, grad_norm={:.4f}'.format(\n",
    "            epoch_id + 1, iter_id + 1, len(loader), loss_mc, loss_t, loss_r, loss, norm_factor, grad_norm))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test inputs: drawn the same way as the training inputs (no noise is added)\n",
    "batch_in_pose_test = torch.randn(1024, 7, device=device)\n",
    "batch_in_pose_test[:, 2] += 5  # keep z in front of the camera plane\n",
    "batch_in_pose_test[:, 3:] = F.normalize(batch_in_pose_test[:, 3:], dim=-1)  # unit quaternion\n",
    "# one camera matrix per test sample\n",
    "batch_cam_mats = cam_mats.expand(batch_in_pose_test.shape[0], -1, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch inference (gradients are not needed at test time)\n",
    "with torch.no_grad():\n",
    "    pose_opt = model.forward_test(batch_in_pose_test, batch_cam_mats)\n",
    "\n",
    "# evaluation: the ground truth is the input pose itself (identity function)\n",
    "dist_t = (pose_opt[:, :3] - batch_in_pose_test[:, :3]).norm(dim=-1)\n",
    "dot_quat = (pose_opt[:, None, 3:] @ batch_in_pose_test[:, 3:, None]).squeeze(-1).squeeze(-1)\n",
    "# abs() treats q and -q as the same rotation; clamp to 1 because rounding can\n",
    "# push the magnitude slightly above 1, where acos would return NaN\n",
    "dist_theta = 2 * torch.acos(dot_quat.abs().clamp(max=1.0))\n",
    "print('Mean Translation Error: {:.4f}'.format(dist_t.mean()))\n",
    "print('Mean Orientation Error: {:.4f}'.format(dist_theta.mean()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EPro-PnP-rel",
   "language": "python",
   "name": "epro-pnp-rel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}